package com.rtsapp.server.network.protocol.crypto.server;

import com.rtsapp.server.logger.Logger;
import com.rtsapp.server.logger.LoggerFactory;
import com.rtsapp.server.network.protocol.crypto.IAppChannelInitializer;
import com.rtsapp.server.network.protocol.crypto.RC4InHandler;
import com.rtsapp.server.network.protocol.crypto.RC4OutHandler;
import com.rtsapp.server.network.protocol.crypto.KeyUtils;
import com.rtsapp.server.network.protocol.crypto.cmd.IRSAHandshakeStartCMD;
import com.rtsapp.server.network.protocol.crypto.cmd.RSAHandshakeCompeleteCMD;
import com.rtsapp.server.network.protocol.crypto.cmd.RSAHandshakeStartCMD;
import com.rtsapp.server.utils.cypto.RSA;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;

import javax.crypto.NoSuchPaddingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.concurrent.TimeUnit;

/**
 * Server side of the RSA connection handshake.
 * <p>
 * When RSA key material is configured, the server generates two RC4 keys (one per direction),
 * encrypts both with the client's RSA public key and sends them in a {@link RSAHandshakeStartCMD}.
 * The client's reply carries the same two keys, which the server decrypts with its own RSA private
 * key and compares against the keys it generated. A match shows that each side holds the key pair
 * the other side expects. The server then sends a {@link RSAHandshakeCompeleteCMD}, replaces the
 * handshake pipeline with the application handlers from the original
 * {@link IAppChannelInitializer}, and prepends the RC4 decrypt/encrypt handlers.
 * <p>
 * When no {@link ServerRSAHandshakeKey} is supplied, the handshake is skipped. Only the completion
 * command (flagged as not encrypted) is sent, then the application handlers are installed as-is.
 * <p>
 * NOTE(review): the class keeps unsynchronised mutable state; presumably every method is invoked
 * from the channel's event loop — confirm with the callers.
 */
public class ServerRSAHandshake {

    private static final Logger LOG = LoggerFactory.getLogger(ServerRSAHandshake.class);

    /** Handshake timeout in milliseconds (30 seconds). */
    private static final long DEFAULT_SHAKE_TIMEOUT_MILS = 30 * 1000;

    /** Maximum frame length accepted while the handshake is in progress. */
    private static final int MAX_INPUT_FRAME_LENGTH = 1024;
    /** Offset of the length field inside each handshake frame. */
    private static final int LENGTH_FIELD_OFFSET = 0;
    /** Size in bytes of the length field of each handshake frame. */
    private static final int LENGTH_FIELD_LENGTH = 4;

    /** The channel being handshaken. */
    private final SocketChannel channel;
    /** Installs the application handlers once the handshake has finished. */
    private final IAppChannelInitializer originChannelInitializer;
    /** RSA key material; {@code null} means the connection is not encrypted. */
    private final ServerRSAHandshakeKey keyRSA;
    /** Server private key, used to decrypt the client's reply. */
    private final RSA serverPrivateRSA;
    /** Client public key, used to encrypt the RC4 keys sent to the client. */
    private final RSA clientPublicRSA;
    /** Generated RC4 key handed to {@link RC4OutHandler} (server-to-client traffic). */
    private final byte[] serverCryptKey;
    /** Generated RC4 key handed to {@link RC4InHandler} (client-to-server traffic). */
    private final byte[] clientCryptKey;
    /** The next step the handshake has to perform. */
    private HandshakeState state = HandshakeState.STATE_INIT_RSA_HANDERS;

    /**
     * Prepares the handshake for one accepted channel. When key material is given, this loads the
     * RSA keys and generates fresh RC4 keys.
     *
     * @param channel                  the accepted channel
     * @param originChannelInitializer installs the application handlers after the handshake
     * @param keyRSA                   RSA key material, or {@code null} to disable encryption
     * @throws NoSuchPaddingException   propagated from the RSA key setup / RC4 key generation
     * @throws NoSuchAlgorithmException propagated from the RSA key setup / RC4 key generation
     * @throws InvalidKeySpecException  propagated from the RSA key setup
     */
    public ServerRSAHandshake(SocketChannel channel, IAppChannelInitializer originChannelInitializer, ServerRSAHandshakeKey keyRSA) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeySpecException {
        this.channel = channel;
        this.originChannelInitializer = originChannelInitializer;
        this.keyRSA = keyRSA;

        if (keyRSA != null) {
            serverPrivateRSA = new RSA();
            serverPrivateRSA.initPrivateKey(keyRSA.getN_SERVER(), keyRSA.getE_SERVER(), keyRSA.getD_SERVER());

            clientPublicRSA = new RSA();
            clientPublicRSA.initPublicKey(keyRSA.getN_CLIENT(), keyRSA.getE_CLIENT());

            serverCryptKey = KeyUtils.generateRC4CrpytKey();
            clientCryptKey = KeyUtils.generateRC4CrpytKey();
        } else {
            serverPrivateRSA = null;
            clientPublicRSA = null;

            serverCryptKey = null;
            clientCryptKey = null;
        }

        // Record the actual next step. Previously this jumped straight to STATE_INIT_RC4_HANDERS,
        // which is the post-handshake step. An encrypted connection must first install the RSA
        // handlers; a plain one goes straight to sending the completion command (see start()).
        changeState(isCrypt() ? HandshakeState.STATE_INIT_RSA_HANDERS : HandshakeState.STATE_SEND_COMPLETE_CMD);
    }

    /**
     * Starts the handshake. Encrypted connections get the RSA handshake handlers. Plain
     * connections immediately receive the completion command and then the application handlers.
     */
    public void start() throws Exception {
        if (isCrypt()) {
            initRSAHandlers();
        } else {
            sendCompleteCMD();
        }
    }

    /**
     * Installs the handlers used during the RSA handshake, in this order:
     * <ul>
     *   <li>a length-prefixed frame decoder;</li>
     *   <li>the {@link ServerRSAHandshakeHandler};</li>
     *   <li>a reader-idle timer;</li>
     *   <li>a handler that closes the channel when that timer fires.</li>
     * </ul>
     * Any failure closes the channel.
     * <p>
     * NOTE(review): this only advances the state to {@code STATE_SEND_START_CMD}; it does not send
     * the start command itself. Presumably {@link ServerRSAHandshakeHandler} calls
     * {@link #sendStartCMD()} — confirm.
     */
    public void initRSAHandlers() {
        try {
            ChannelPipeline pipeline = channel.pipeline();

            // 4-byte length prefix at offset 0; handshake frames are capped at MAX_INPUT_FRAME_LENGTH.
            pipeline.addLast(new LengthFieldBasedFrameDecoder(MAX_INPUT_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH));
            pipeline.addLast(new ServerRSAHandshakeHandler(this));

            // Close the channel when no read is seen for DEFAULT_SHAKE_TIMEOUT_MILS (30 seconds).
            // NOTE(review): this sits after ServerRSAHandshakeHandler, so it only observes reads that
            // handler forwards; it may effectively act as a hard deadline for the whole handshake.
            pipeline.addLast(new IdleStateHandler(DEFAULT_SHAKE_TIMEOUT_MILS, 0, 0, TimeUnit.MILLISECONDS));
            pipeline.addLast(new ChannelInboundHandlerAdapter() {
                @Override
                public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                    if (!(evt instanceof IdleStateEvent)) {
                        // Not ours: pass it on instead of silently dropping it.
                        ctx.fireUserEventTriggered(evt);
                        return;
                    }
                    switch (((IdleStateEvent) evt).state()) {
                        case READER_IDLE:
                            closeUnexpectedly("[RSAHandshake] timeout, will close channel");
                            break;
                        default:
                            // Writer/all-idle timeouts are disabled (0) above; nothing to do.
                            break;
                    }
                }
            });

            // Next step: send the start command.
            changeState(HandshakeState.STATE_SEND_START_CMD);

        } catch (Throwable ex) {
            closeUnexpectedly(ex, "initRSAHandlers");
        }
    }

    /**
     * Sends the generated RC4 keys to the client, each encrypted with the client's RSA public key.
     * The handshake then waits for the client's reply. Any failure closes the channel.
     */
    public void sendStartCMD() {
        try {
            byte[] encryptedServerKey = clientPublicRSA.encrypt(this.serverCryptKey);
            byte[] encryptedClientKey = clientPublicRSA.encrypt(this.clientCryptKey);

            RSAHandshakeStartCMD cmd = new RSAHandshakeStartCMD();
            cmd.setServerCryptKey(encryptedServerKey);
            cmd.setClientCryptKey(encryptedClientKey);

            sendData(cmd);

            changeState(HandshakeState.STATE_RECV_START_ACK_CMD);

        } catch (Throwable ex) {
            closeUnexpectedly(ex, "sendStartCMD error");
        }
    }

    /**
     * Handles the client's reply to the start command. Both keys are decrypted with the server's
     * RSA private key and compared against the generated keys. On a match the completion command
     * is sent; otherwise, or on any failure, the channel is closed.
     *
     * @param buffer the framed reply received from the client
     */
    public void recvStartAckCMD(ByteBuf buffer) {
        try {
            RSAHandshakeStartCMD cmd = new RSAHandshakeStartCMD();
            cmd.readFromBuffer(buffer);

            byte[] echoedServerKey = serverPrivateRSA.decrypt(cmd.getServerCryptKey());
            byte[] echoedClientKey = serverPrivateRSA.decrypt(cmd.getClientCryptKey());

            // The keys are secrets and the echoed bytes are client-controlled, so compare in
            // constant time. The non-short-circuit '&' is deliberate: both comparisons always run.
            boolean keysMatch = MessageDigest.isEqual(serverCryptKey, echoedServerKey)
                    & MessageDigest.isEqual(clientCryptKey, echoedClientKey);

            if (keysMatch) {
                changeState(HandshakeState.STATE_SEND_COMPLETE_CMD);
                sendCompleteCMD();
            } else {
                closeUnexpectedly("cert is error");
            }
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "recvStartAckCMD error");
        }
    }

    /**
     * Sends the completion command, which tells the client whether the connection is encrypted.
     * Once the write has succeeded, the handshake handlers are replaced by the application
     * handlers. If the write fails, the channel is closed instead.
     */
    public void sendCompleteCMD() {
        try {
            RSAHandshakeCompeleteCMD cmd = new RSAHandshakeCompeleteCMD();
            cmd.setCrypt(isCrypt());

            sendData(cmd, new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) {
                    // Netty notifies listeners on failure and cancellation too. Never install the
                    // application pipeline unless the completion command actually went out.
                    if (!future.isSuccess()) {
                        closeUnexpectedly(future.cause(), "sendCompleteCMD write failed");
                        return;
                    }

                    changeState(HandshakeState.STATE_INIT_RC4_HANDERS);

                    try {
                        if (isCrypt()) {
                            initRC4Handlers();
                        } else {
                            initOriginHandlers();
                        }
                    } catch (Throwable ex) {
                        // Netty only logs exceptions thrown from a listener; close explicitly so the
                        // channel is not left half-initialised.
                        closeUnexpectedly(ex, "initOriginHandlers error");
                    }
                }
            });

        } catch (Throwable ex) {
            closeUnexpectedly(ex, "sendCompleteCMD");
        }
    }

    /**
     * Replaces the whole pipeline with the application handlers and prepends the RC4 handlers:
     * inbound data is decrypted with the client key, outbound data is encrypted with the server
     * key. It then re-fires channelRegistered/channelActive through the new pipeline. Any failure
     * closes the channel.
     */
    public void initRC4Handlers() {
        try {
            ChannelPipeline pipeline = channel.pipeline();

            // 1. Drop every handler currently in the pipeline (including the handshake handlers).
            while (pipeline.first() != null) {
                pipeline.removeFirst();
            }

            // 2. Let the original initializer install the application handlers.
            originChannelInitializer.initChannel(channel);

            // 3. Put the RC4 handlers in front of the application handlers.
            pipeline.addFirst(new RC4InHandler(clientCryptKey));
            pipeline.addFirst(new RC4OutHandler(serverCryptKey));

            // Replay the lifecycle events so the newly installed handlers see them.
            pipeline.fireChannelRegistered();
            pipeline.fireChannelActive();

            // 4. The handshake is over.
            changeState(HandshakeState.STATE_EXIT);

        } catch (Throwable ex) {
            closeUnexpectedly(ex, "initRC4Handlers error");
        }
    }

    /**
     * Installs the application handlers for an unencrypted connection and ends the handshake.
     */
    private void initOriginHandlers() throws Exception {
        originChannelInitializer.initChannel(channel);
        changeState(HandshakeState.STATE_EXIT);
    }

    /**
     * Dispatches one inbound handshake frame. Frames are only accepted while waiting for the
     * client's reply; any other state closes the channel.
     *
     * @param buffer the framed data received from the client
     */
    public void recvData(ByteBuf buffer) {
        if (state == HandshakeState.STATE_RECV_START_ACK_CMD) {
            recvStartAckCMD(buffer);
        } else {
            closeUnexpectedly("recvData error:  state != STATE_RECV_START_ACK_CMD");
        }
    }

    /**
     * Logs the reason and closes the channel.
     *
     * @param message description of what went wrong
     */
    public void closeUnexpectedly(String message) {
        closeUnexpectedly(null, message);
    }

    /**
     * Logs the reason (with an optional cause) and closes the channel.
     *
     * @param ex      the cause, may be {@code null}
     * @param message description of what went wrong
     */
    public void closeUnexpectedly(Throwable ex, String message) {
        LOG.error(ex, "[RSAHandshake] closeUnexpectedly, channel={}, message={}", channel, message);
        this.channel.close();
    }

    /** Returns whether RSA key material was supplied, i.e. whether the connection is encrypted. */
    private boolean isCrypt() {
        return keyRSA != null;
    }

    private void changeState(HandshakeState state) {
        this.state = state;
    }

    private void sendData(IRSAHandshakeStartCMD cmd) {
        sendData(cmd, null);
    }

    /**
     * Serializes {@code cmd} and writes it to the channel. If {@code listener} is non-null, it is
     * notified when the write completes. A serialization/write failure closes the channel.
     */
    private void sendData(IRSAHandshakeStartCMD cmd, ChannelFutureListener listener) {
        try {
            ByteBuf buf = IRSAHandshakeStartCMD.outputByteBuf(cmd);

            ChannelFuture f = channel.writeAndFlush(buf);

            if (listener != null) {
                f.addListener(listener);
            }
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "sendData error");
        }
    }

    /**
     * Handshake state. Each value names the next step the connection has to perform.
     */
    enum HandshakeState {
        STATE_INIT_RSA_HANDERS,
        STATE_SEND_START_CMD,
        STATE_RECV_START_ACK_CMD,
        STATE_SEND_COMPLETE_CMD,
        STATE_INIT_RC4_HANDERS,
        STATE_EXIT
    }
}
